# =============================================================================
# =============================================================================
# # Create the dataset of treatment and cohorts
# =============================================================================
# =============================================================================

# =============================================================================
# Prepare Geographic layer of analysis
# =============================================================================

# Select the geographic unit of analysis from the configuration.
# Each supported branch defines:
#   geog   - the name of the identifier column ('zip' or 'cbg')
#   df_geo - one row per geographic unit with centroid columns 'clat'/'clong'
# An unsupported value now fails immediately instead of leaving geog/df_geo
# undefined (or stale from an earlier interactive run) and breaking later.
if ts['geog'] == 'zip':
    geog = 'zip'
    # =============================================================================
    # ZIP Distances

    #I calculated zip code centroids as the weighted average of tabulation blocks with population weights from the 2010 census.
    #   This results in centers of mass that are closer to the city center and much more relevant for large zip codes (particularly for suburbs)
    #   While there are some zip code differences between the zcta and zip code map it is generally under 1M people or 0.3%
    zs = pd.read_parquet(cdd['p_d_acs_zip'] + r'\acs_2014_2022_s3.parquet')
    #Keep the 2020 rows only; they supply the zip -> state lookup used below.
    zs = zs[zs.year==2020][['zip','state','population']]
    #NOTE(review): ALAND10 is renamed to 'sqkm' without any unit conversion here - confirm the units of the source column.
    zp = pd.read_excel(cdd['p_d_geo_demo'] + r'\zip_pop.xlsx').rename(columns={'wlat':'clat','wlong':'clong','ZCTA5CE10':'zip','ALAND10':'sqkm'})[['zip','sqkm','long','lat','clong','clat']]
    zp['state'] = zp['zip'].map(zs.set_index('zip')['state'])
    df_geo = zp

elif ts['geog'] == 'cbg':
    geog = 'cbg'
    # =============================================================================
    # CBG Distances
    geo_year=2019
    #State postal abbreviation -> two-digit code (not referenced again in this section).
    state_codes = {
        'WA': '53', 'DE': '10', 'DC': '11', 'WI': '55', 'WV': '54', 'HI': '15',
        'FL': '12', 'WY': '56', 'PR': '72', 'NJ': '34', 'NM': '35', 'TX': '48',
        'LA': '22', 'NC': '37', 'ND': '38', 'NE': '31', 'TN': '47', 'NY': '36',
        'PA': '42', 'AK': '02', 'NV': '32', 'NH': '33', 'VA': '51', 'CO': '08',
        'CA': '06', 'AL': '01', 'AR': '05', 'VT': '50', 'IL': '17', 'GA': '13',
        'IN': '18', 'IA': '19', 'MA': '25', 'AZ': '04', 'ID': '16', 'CT': '09',
        'ME': '23', 'MD': '24', 'OK': '40', 'OH': '39', 'UT': '49', 'MO': '29',
        'MN': '27', 'MI': '26', 'RI': '44', 'KS': '20', 'MT': '30', 'MS': '28',
        'SC': '45', 'KY': '21', 'OR': '41', 'SD': '46', "AS": "60", "MP": "69",
        "VI": "78", "GU": "66"
    }

    ##The user will need to run this at least once but afterwards they can comment it out for speed.
    ##Import all of the shapefiles.
    # cbg = [gpd.read_file(f) for f in glob.glob(cdd['geo_t'] + r'/BG_'+str(geo_year)+'/*.zip')]
    # cbg = pd.concat(cbg,ignore_index=True,sort=False)

    # #For each cbg find the centroid
    # cbg['clat']=cbg.geometry.centroid.y
    # cbg['clong']=cbg.geometry.centroid.x

    # #Generate a column with the cbg identifier (as numeric)
    # cbg['cbg']=pd.to_numeric(cbg['GEOID'])
    # cbg['fips']=np.floor(cbg['cbg']/10000000)
    # cbg[['fips','cbg','clat','clong']].to_parquet(cdd['p_d_geo'] + r'\cbg_geo'+str(geo_year)+'.parquet')
    #Load the cached centroid table written by the (commented-out) one-time step above.
    cbg = pd.read_parquet(cdd['p_d_geo'] + r'\cbg_geo'+str(geo_year)+'.parquet')
    df_geo = cbg

else:
    raise ValueError("ts['geog'] must be 'zip' or 'cbg', got " + repr(ts['geog']))
    


# =============================================================================
# Find matches on a quarterly basis.
# =============================================================================
#Load the plasma center dataset (each record becomes one target further below).
pcp = pd.read_parquet(cdd['p_d_pcp'] + r'\pcp.parquet')


#======================================================================
#Treatment-definition settings, all read from the ts configuration:
#   window    - number of years kept around an opening
#   radius    - treatment radius, compared against the computed distances (file names use radius/1000 with a 'k' suffix)
#   cti       - smallest change in treatment intensity that counts as a material treatment event
#   threshold - distance cut-off that intensities and distances are compared against
window, radius = ts['window'], ts['radius']
cti, threshold = ts['cti'], ts['threshold']


#Frequency of analysis and the matching grid of dates (2000 through 2030).
#   The constant 'key' column lets the grid be cross-joined onto other frames.
freq = ts['freq']
dates = pd.DataFrame(
    pd.date_range('2000-01-01', '2030-12-31', name='date', freq=freq.upper())
).assign(key=0)



#Make a list of all of the treatment cohorts
def treated_cohorts(target,radius,candidates):
    """Build the geography-by-date panel of units treated by one plasma center.

    Parameters
    ----------
    target : dict
        One plasma-center record. Keys read here: 'clat', 'clong', 'mfeiu',
        'open', 'close', 'gdate_p', 'gdate_np', 'gap', 'clean_open'.
    radius : number
        Maximum distance from the center, in the same unit as the `.m`
        attribute of the vincenty result (presumably meters).
    candidates : DataFrame
        Geographic units with centroid columns 'clat' and 'clong'.

    Returns
    -------
    DataFrame
        Every candidate closer than `radius`, crossed with the module-level
        `dates` grid, annotated with the center's attributes ('cohort' holds
        target['mfeiu']) and cut off `window` years after the close date.

    Relies on the module-level names `dates`, `window` and `vinc`.
    """
    #Cheap pre-filter: keep candidates inside a square of +/- 2*radius/100000 degrees around the center.
    box = 2*radius/100000
    lat_ok = abs(candidates.clat-target['clat'])<box
    long_ok = abs(candidates.clong-target['clong'])<box
    nearby = candidates[lat_ok&long_ok].copy()
    #Exact distance from the center to each surviving candidate (vincenty inverse formula).
    center = (target['clat'],target['clong'])
    nearby['distance'] = nearby.apply(lambda row: vinc.vincenty_inverse(center,(row['clat'],row['clong'])).m,axis=1)
    #nearby['distance']=nearby.apply(lambda row: distance.distance(center,(row['clat'],row['clong'])).m,axis=1)
    nearby = nearby[nearby['distance']<radius].sort_values(by='distance').reset_index(drop=True)
    #Cross-join with the date grid through the constant key.
    nearby['key'] = 0
    panel = pd.merge(nearby,dates,on='key')
    #First grid date on/after the opening and last grid date before it.
    first_date_after_open = dates.loc[dates['date']>=target['open'],'date'].min()
    last_date_before_open = dates.loc[dates['date']<target['open'],'date'].max()
    #Constant descriptive columns (inserted in this order).
    descriptors = {
        'treated': 1,
        'cohort': target['mfeiu'],
        'open': target['open'],
        'gdate_p': target['gdate_p'],
        'gdate_np': target['gdate_np'],
        'close': target['close'],
        'open_qpost': first_date_after_open,
        'open_qpre': last_date_before_open,
        'pclat': target['clat'],
        'pclong': target['clong'],
        'clean_gap': target['gap'],
        'clean_open': target['clean_open'],
        #'clean_opensv': target['clean_opensv'],
    }
    for column, value in descriptors.items():
        panel[column] = value
    #Drop observations later than `window` years after the close date (first of the closing month).
    last_kept = dt.datetime(target['close'].year + window,target['close'].month,1)
    return panel[panel.date<=last_kept]

#Build the treated panel for every plasma center (thread-based workers), stack the pieces and save them.
targets = pcp.to_dict(orient='records')
cohort_jobs = (delayed(treated_cohorts)(target=center,radius=radius,candidates=df_geo) for center in tqdm(targets))
treated_coh = pd.concat(
    Parallel(n_jobs=30, prefer="threads")(cohort_jobs),
    ignore_index=True,
    sort=False,
)
treated_coh.to_parquet(cdd['p_d_geo_tcp'] + r'\treated_coh_'+str(int(radius/1000))+'k_'+geog+'_'+freq+'.parquet')




# =============================================================================
# Generate geo treatment intensity at each point in time
# =============================================================================
#Make a time series for each geo-cohort that will allow me to track exposure over time.
treated_coh = pd.read_parquet(cdd['p_d_geo_tcp'] + r'\treated_coh_'+str(int(radius/1000))+'k_'+geog+'_'+freq+'.parquet')
#I have two definitions of treatment intensity which are conservative in different ways to be used for different analysis.
#   For the primary DiD analysis I use the expected opening date to calculate exposure, where expected open is the midpoint of gdate_p and gdate_np.
#   (gdate_p and gdate_np are only carried along in this step; exposure below is computed from 'open' and 'close'.)
suf = '' #for DiD studies
treat_cols = [geog, 'open','close','distance','cohort','clean_gap','clean_open','gdate_p','gdate_np']
treat = treated_coh[treat_cols].drop_duplicates().reset_index(drop=True).sort_values(by=[geog,'open','distance'])
#Cross-join every geo-cohort pair with the full date grid.
treat['key'] = 0
treat = pd.merge(treat,dates,how='left',on='key')
geo_date = [geog,'date']
geo_date_present = [geog,'date','present']
#Flag geo-cohort-dates on which the plasma center is actually open (accounts for closures)
treat['present'] = ((treat['date']>=treat['open'])&(treat['close']>=treat['date']))*1
treat['any_pc'] = (treat.groupby(geo_date)['present'].transform('sum')>0)*1
#Intensity = distance to the nearest center within the same present/absent group; blank when no center is present at all
treat['intensity'] = treat.groupby(geo_date_present)['distance'].transform('min')
treat['intensity'] = np.where(treat['any_pc'],treat['intensity'],np.nan)
# treat['intensity_2nd'] = np.where(treat['intensity']!=treat['distance'],treat['distance'],np.nan)
# treat['intensity_2nd'] = treat.groupby([geog,'date','present'])['intensity_2nd'].transform('min')
# treat['intensity_2nd'] = np.where(treat.groupby([geog,'date'])['present'].transform('sum')>1,treat['intensity_2nd'],np.nan)
# treat['intensity_3rd'] = np.where(treat[['intensity','intensity_2nd']].max(axis=1) < treat['distance'], treat['distance'], np.nan)
# treat['intensity_3rd'] = treat.groupby([geog,'date','present'])['intensity_3rd'].transform('min')
# treat['intensity_3rd'] = np.where(treat.groupby([geog,'date'])['present'].transform('sum')>1,treat['intensity_3rd'],np.nan)
#Cohort id of the present center whose distance equals the intensity, spread over the group
treat['closest_cohort'] = np.where((treat['present'])&(treat['intensity']==treat['distance']),treat['cohort'],np.nan)
treat['closest_cohort'] = treat.groupby(geo_date_present)['closest_cohort'].transform('max')
# treat['closest_cohort_2nd']=np.where((treat['present'])&(treat['intensity_2nd']==treat['distance']),treat['cohort'],np.nan)
# treat['closest_cohort_2nd']=treat.groupby([geog,'date','present'])['closest_cohort_2nd'].transform('max')
# treat['closest_cohort_3rd']=np.where((treat['present'])&(treat['intensity_3rd']==treat['distance']),treat['cohort'],np.nan)
# treat['closest_cohort_3rd']=treat.groupby([geog,'date','present'])['closest_cohort_3rd'].transform('max')
# a=treat.closest_cohort.value_counts().reset_index()

#Column suffix -> distance cut-off, in the unit of 'distance'.
ring_cutoffs = {'250':25000,'150':15000,'100':10000,'75':7500,'50':5000,
                '25':2500,'20':2000,'15':1500,'10':1000,'5':500}
#wi<suffix>: the cohort's plasma center is present and within the cut-off at that point in time.
for ring, cutoff in ring_cutoffs.items():
    treat['wi'+ring] = ((treat['distance']<=cutoff)&(treat['present']==1))*1
#npc<suffix>: number of such centers for the geo-date (within the present/absent group).
for ring in ring_cutoffs:
    treat['npc'+ring] = treat.groupby(geo_date_present)['wi'+ring].transform('sum')
#   Mark the rows carrying the highest "present" value for each geo-date,
#   keep only those rows and then reduce to a single row per geo-date.
treat['mark'] = (treat['present']==treat.groupby(geo_date)['present'].transform('max'))*1
treat = treat[treat.mark==1].drop_duplicates(subset=geo_date)
tvars = [geog, 'date', 'intensity', 'closest_cohort'] + [col for col in list(treat) if 'npc' in col]
treat = treat[tvars]
#Calculate 1) the change in treatment intensity,
#          2) the largest absolute change in treatment intensity over the preceding periods
#          3) the most recent absolute and pct change in treatment
#Sort and give treat a clean 0..n-1 index. The rows kept by the earlier filter carry
#   non-contiguous labels, and every assignment below aligns on index labels.
treat = treat.sort_values(by=[geog,'date']).reset_index(drop=True)
treat['intensity_lag'] = treat.groupby(geog)['intensity'].shift(1)
no_pc_now = treat['intensity'].isna()
no_pc_before = treat['intensity_lag'].isna()
#Delta is 0 when there is no center in either period, cti+1 when exactly one side is missing
#   (treatment starts/ends, or first row of the geog), and the plain difference otherwise.
treat['intensity_delta'] = np.where(no_pc_now&no_pc_before,
                                    0,(treat['intensity']-treat['intensity_lag']).fillna(cti+1))
#Absolute change; rows where treatment starts or ends are coded as 100000.
#   The original recoded them with replace(1,100000), which only matched the cti+1 fill value
#   when cti==0 (and also hit genuine changes of exactly 1). The explicit mask works for any cti.
treat['intensity_absdelta'] = np.where(no_pc_now^no_pc_before,100000,np.abs(treat['intensity_delta']))
#Largest absolute change over an 11-period rolling window within each geog, lagged one period below.
#   Only the group level of the rolling result is dropped so that values stay attached to their
#   own rows (a full reset_index() renumbers them and misaligns the assignment).
#   NOTE(review): the earlier comment spoke of 3 years (12 quarters) but the window is 11 periods - confirm.
treat['intensity_maxabsdelta3y'] = (treat.groupby(geog)['intensity_absdelta']
                                         .rolling(window=11).max()
                                         .reset_index(level=0,drop=True))
treat['intensity_maxabsdelta3y'] = treat.groupby(geog)['intensity_maxabsdelta3y'].shift(1).fillna(0)
#Change relative to the previous intensity (0 where the previous intensity is missing).
treat['intensity_pctdelta'] = (treat['intensity_absdelta'] / treat['intensity_lag']).fillna(0)

treat[(treat['date']>=dt.datetime(2000,1,1))&(treat['date']<=dt.datetime(2030,12,31))].to_parquet(cdd['p_d_geo_tcp'] + r'\\treat_'+str(int(radius/1000))+'k_'+geog+'_'+freq+suf+'.parquet')



#Flag the geo-cohort pairs that may be used as treatment events. A pair qualifies for 'use' if:
#   0) the opening is after 1999-01-01 (earlier locations may already have existed) and it is a clean open, AND
#   1) it is the first treatment cohort of the geog, OR
#   2) the intensity at opening is above threshold, OR
#   3) the largest recent change in treatment intensity is at most cti.
#   'use_lib' additionally admits changes of more than 3000 that are also more than 75% of the previous intensity
#           (together these imply that the previous plasma center was at least 4k from the geog and the new location is within 1k of the geog)
cohort_cols = [geog,'cohort','open','close','open_qpost','open_qpre','distance','pclat','pclong','clean_gap','clean_open']
cohorts = treated_coh[cohort_cols].drop_duplicates(subset=[geog,'cohort']).reset_index(drop=True)
cohorts.sort_values(by=[geog,'open','distance'],inplace=True)
cohorts['temp'] = 1 #Used to find the first treatment event
#Attach the treatment intensity measured on the first grid date on/after the opening.
intensity_cols = [geog, 'date', 'intensity','intensity_delta','intensity_absdelta','intensity_maxabsdelta3y','intensity_pctdelta']
cohorts = pd.merge(cohorts,treat[intensity_cols],how='left',left_on=[geog,'open_qpost'],right_on=[geog,'date']).drop(columns=['date'])
eligible_open = (cohorts['open']>dt.datetime(1999,1,1))&(cohorts['clean_open']>0)
first_in_geog = cohorts.groupby(geog)['temp'].cumsum()==1
beyond_threshold = cohorts['intensity']>threshold
quiet_history = cohorts['intensity_maxabsdelta3y']<=cti
cohorts['use'] = (eligible_open&(first_in_geog|beyond_threshold|quiet_history))*1
large_jump = (cohorts['intensity_absdelta']>3000)&(cohorts['intensity_pctdelta']>0.75)
cohorts['use_lib'] = (eligible_open&((cohorts['use']==1)|large_jump))*1
cohorts.drop(columns=['temp'],inplace=True)

#Look forward for each geo and determine the viable post period as the minimum of:
#   1) The last grid date before the first subsequent opening where exposure changed materially
#   2) 4 years after the first period of treatment.
#       (if there was no material change in exposure then there is no need to worry about conflating the two treatments)
#   NOTE1: The post period is allowed to vary within cohorts by geog since not all geogs are treated in the same way.
#               For a global cohort level through date use throughmax (i.e. for control cohorts)
next_intensity = cohorts.groupby([geog])['intensity'].shift(-1)
material_next = abs((cohorts['intensity']-next_intensity).fillna(0))>cti
cohorts['intensity_leaddelta'] = (material_next&(cohorts['intensity']<threshold))*1
cohorts['open_lead'] = cohorts.groupby(geog)['open_qpre'].shift(-1)
cohorts['through'] = cohorts.apply(lambda row: row['open_lead'] if row['intensity_leaddelta']==1 else np.nan,axis=1)
#Back-fill within the geog so earlier cohorts stop at the next material opening; 2025-01-01 when there is none.
cohorts['through'] = cohorts.groupby(geog)['through'].transform(lambda grp: grp.bfill()).fillna(dt.datetime(2025,1,1))
#Cap at four years after open_qpost (first day of the same month).
cohorts['temp'] = pd.to_datetime({'year':cohorts['open_qpost'].dt.year+4,'month':cohorts['open_qpost'].dt.month,'day':1})
cohorts['through'] = cohorts[['through','temp']].min(axis=1)
cohorts['throughmax'] = cohorts.groupby(['cohort','open'])['through'].transform('max')
cohorts.drop(columns=['intensity_leaddelta','open_lead','temp'],inplace=True)
cohorts.to_parquet(cdd['p_d_geo_tcp'] + r'\cohorts_'+str(int(radius/1000))+'k_'+geog+'_'+freq+'.parquet')

#Attach the cohort-level refinements to the treated panel and enforce the restrictions:
#   keep usable (use_lib) geo-cohorts, dates up to 'through', and dates no earlier than `window` years before the opening month.
panel_cols = [geog, 'date', 'cohort', 'open', 'treated', 'distance']
refinement_cols = [geog,'cohort','through','use_lib','use','open_qpre']
treated_cohs = pd.merge(treated_coh[panel_cols],cohorts[refinement_cols],how='left',on=[geog,'cohort'])
window_start = pd.to_datetime({'year':treated_cohs['open'].dt.year-window,'month':treated_cohs['open'].dt.month,'day':1})
keep_rows = (treated_cohs['use_lib']==1)&(treated_cohs['date']<=treated_cohs['through'])&(treated_cohs['date']>=window_start)
treated_cohs = treated_cohs[keep_rows]
treated_cohs.info(memory_usage='deep')
treated_cohs.to_parquet(cdd['p_d_geo_tcp'] + r'\treated_cohs_'+str(int(radius/1000))+'k_'+geog+'_'+freq+'.parquet')



#Use future opening cohorts as controls for the current cohort
#   1) Require that control cbgs will be treated at some point in the future (not requiring a treatment intensity change) and 
#           that this future opening is clean (to avoid observations where the cbg acts as control but is actually being treated) BUT
#       2) will not experience a treatment intensity increase for at least 2 year after this opening (post buffer) AND
#       3) has not experienced a treatment intensity increase for the past 3 years. (pre buffer)
#       4) require that there is some future treatment is within the next 7 years.
#   Keep 4 years of control data before and after the current cohort open date (to keep the size down)
def cohort_controls(target, cohorts, treated, candidates, treatment, dfp_geo, threshold=10000, cti=0):
    """Build and save the control panel for one treatment cohort.

    Parameters
    ----------
    target : dict
        One cohort record. Keys read here: 'open', 'cohort', 'pclat',
        'pclong', 'throughmax'.
    cohorts : DataFrame
        Geo-cohort table (one row per geog and cohort) with the opening and
        closing dates, distance and intensity columns used below.
    treated : list
        Geog identifiers that must never be used as controls.
    candidates : DataFrame
        Geog/date/cohort panel of all treated units.
    treatment : DataFrame
        Geog/date panel holding 'intensity'.
    dfp_geo : DataFrame
        Balanced geog/date panel with cohort==0 and the geog centroid stored
        in 'pclat'/'pclong'; source of the "local" counterfactuals.
    threshold : number, default 10000
        Distance cut-off applied to intensities and distances.
    cti : number, default 0
        Minimum absolute change in intensity that counts as a treatment change.

    Returns
    -------
    dict
        `target` plus 'return_code': 1 when a control panel was written,
        0 when no control rows were left (nothing is written).

    Side effect: writes the panel to
    cdd['p_d_geo_tcp'] + '\\Cohorts\\control_coh_<cohort>_<freq>.parquet'.
    Relies on the module-level names `geog`, `window`, `freq`, `cdd`, `vinc`
    and (as a fallback distance routine) `distance`.
    """
    #For each geog find the closest treatments in the future (proximate and distal) and past that are above the cti threshold.
    open_past = cohorts[(cohorts.open < target['open'])][[geog,'open']].rename(columns={'open':'open_past'}).drop_duplicates()
    open_future = cohorts[(cohorts.open > target['open'])][[geog,'open']].rename(columns={'open':'open_future'}).drop_duplicates()
    #open_recent: latest earlier opening that mattered for the geog; open_prox: earliest opening on/after the target's that matters.
    open_recent = cohorts[(((cohorts['intensity_absdelta']>cti)&(cohorts['intensity']<=threshold))|(cohorts['distance']<=threshold))&(cohorts.open < target['open'])].groupby(geog)['open'].max().reset_index().rename(columns={'open':'open_recent'})
    open_prox = cohorts[(((cohorts['intensity_absdelta']>cti)&(cohorts['intensity']<=threshold))|(cohorts['distance']<=threshold))&(cohorts.open >= target['open'])].groupby(geog)['open'].min().reset_index().rename(columns={'open':'open_prox'})
    candidate_geo = pd.concat([open_future,open_past],ignore_index=True,sort=False)
    #Sentinel dates stand in for "no such opening" (1990 = none in the past, 2050 = none in the future, 2025 = no proximate one).
    candidate_geo['open_past'] = candidate_geo['open_past'].fillna(dt.datetime(1990,1,1))
    candidate_geo['open_future'] = candidate_geo['open_future'].fillna(dt.datetime(2050,1,1))
    candidate_geo = pd.merge(candidate_geo,open_recent,how='left',on=geog).fillna(dt.datetime(1990,1,1))
    candidate_geo = pd.merge(candidate_geo,open_prox,how='left',on=geog).fillna(dt.datetime(2025,1,1))
    candidate_geo['open_date'] = pd.to_datetime(np.where(candidate_geo['open_future']<dt.datetime(2050,1,1),candidate_geo['open_future'],candidate_geo['open_past']))
    candidate_geo['open_type'] = np.where(candidate_geo['open_future']<dt.datetime(2050,1,1),'future','past')
    candidate_geo = pd.merge(candidate_geo, cohorts, how='left',left_on=[geog,'open_date'],right_on=[geog,'open'])
    
    #restrict the list of candidate geogs to those that satisfy 1-4
    candidate_geo['clean_open_w'] = ((candidate_geo.clean_open==1)|(candidate_geo.open<dt.datetime(2014,1,1))|
                                 ((candidate_geo.clean_open==0)&
                                 (target['open'] < candidate_geo.open - candidate_geo['clean_gap'].fillna(0) * dt.timedelta(days=365))&
                                 (target['open'] + dt.timedelta(days=365*3) < candidate_geo.close)))*1
    
    #ctype: 1 = valid control based on a future opening, 2 = valid control based on a past opening, 0 = not valid.
    candidate_geo['ctype'] = ((candidate_geo['clean_open_w']==1)&(candidate_geo.open_type=='future')&
                                ((candidate_geo['close'] > target['open'] + dt.timedelta(days=365*2))&
                                 (candidate_geo.open_recent < target['open'] - dt.timedelta(days=365*4))&
                                 (candidate_geo.open_prox < target['open'] + dt.timedelta(days=365*10))&
                                ((candidate_geo.open_prox > target['open'] + dt.timedelta(days=365*2))|
                                ((candidate_geo.open_prox > target['open'] + dt.timedelta(days=365))&(target['open'].year==2019))|
                                ((candidate_geo.open_prox > target['open'] + dt.timedelta(days=190))&(target['open'].year>=2020)))))*1  
    candidate_geo['ctype'] = candidate_geo['ctype'] + ((candidate_geo.open_type=='past')&#(candidate_geo.clean_open_w==1)&
                                (candidate_geo['close'] > target['open'] + dt.timedelta(days=365*2))&
                                (candidate_geo.open_recent < target['open'] - dt.timedelta(days = 365*4))&
                                (candidate_geo.open_past > target['open'] - dt.timedelta(days = 365*30))&
                                ((candidate_geo.open_prox > target['open'] + dt.timedelta(days=365*2))|
                                ((candidate_geo.open_prox > target['open'] + dt.timedelta(days=365))&(target['open'].year==2019))|
                                ((candidate_geo.open_prox > target['open'] + dt.timedelta(days=190))&(target['open'].year>=2020))))*2
    cand_vars_keep = [geog, 'open_recent', 'open_prox', 'cohort', 'open', 'close', 'distance', 'pclat', 'pclong',
                            'intensity', 'intensity_delta', 'intensity_absdelta', 'throughmax', 'ctype']

    #Drop closed establishments from the running 
    candidate_geo = candidate_geo[candidate_geo.close > target['open'] - dt.timedelta(days=365*4)]
    #For each geog keep 1) the closest viable counterfactual to the geog by ctype
    #                   2) the most impactful plasma center opening (largest change in treatment intensity) by ctype
    #                   3) the first plasma center treatment (if the opening was after 2003)
    #Require that it has a core of influence.
    #   intensity_absdelta==100000 is the code for "treatment started/ended"; it is excluded from the largest-change ranking.
    candidate_geo['intensity_absdelta2'] = np.where((candidate_geo['intensity_absdelta']>0)&(candidate_geo['intensity_absdelta']<100000),candidate_geo['intensity_absdelta'],np.nan)
    candidate_geo['mark'] = np.where(((candidate_geo.intensity_absdelta==100000)&(candidate_geo['open']>dt.datetime(2003,1,1)))|
                                      ((candidate_geo['ctype']>0)&(candidate_geo.distance==candidate_geo.groupby([geog,'ctype'])['distance'].transform('min')))|
                                      ((candidate_geo['ctype']>0)&(candidate_geo.intensity_absdelta2==candidate_geo.groupby([geog,'ctype'])['intensity_absdelta2'].transform('max'))), 1, np.nan)
    #mark2: the counterfactual cohort has at least one valid, marked geog closer than threshold/2.
    candidate_geo['mark2'] = np.where(((candidate_geo['ctype']>0)&(candidate_geo['mark']>=0)&(candidate_geo['distance']<threshold/2)),1,np.nan)
    candidate_geo['mark2'] = candidate_geo.groupby('cohort')['mark2'].transform('max')
    candidate_geo = candidate_geo[((candidate_geo.mark>=0)&(candidate_geo.mark2>=0))|(candidate_geo.groupby(geog)['ctype'].transform('max')==0)].drop_duplicates(subset=[geog,'ctype'])[cand_vars_keep]

    #Build the panel
    temp=pd.merge(candidates,candidate_geo,how='inner',on=[geog,'cohort'])
    #Add in local counterfactuals (adding in ALL potential counterfactual zip codes will be massive and make everything crash)
    local_cf  = dfp_geo[~dfp_geo[geog].isin(list(candidate_geo[geog].unique())+treated)]
    local_cfs = local_cf.drop_duplicates(subset=[geog])
    #Keep only counterfactuals within 1deg distance of the target plasma center (roughly 100km)
    local_cfs=local_cfs[(abs(local_cfs.pclat-target['pclat'])<1)&(abs(local_cfs.pclong-target['pclong'])<1)].copy()
    temp=pd.concat([temp,local_cf[local_cf[geog].isin(local_cfs[geog].unique())]],ignore_index=True,sort=False)
    #Subset to keep only local or valid counterfactual geogs
    #   (& binds tighter than |: local rows, or rows of another cohort with a valid ctype)
    temp=temp[(temp.cohort==0)|(temp.cohort!=target['cohort'])&(temp.ctype>0)]
    temp=temp[(~temp[geog].isin(treated))]
    
    #keep only observations before the proximate treatment and restrict to the treatment window.
    coh_open_post = target['open'] + dt.timedelta(days=365 * window + 120)
    coh_open_pre = target['open'] - dt.timedelta(days=365 * window + 120)
    temp=temp[((temp.date<temp.open_prox)|(temp.cohort==0))&(temp.date<=target['throughmax'])&(temp.date >= coh_open_pre)&(temp.date <= coh_open_post)]
    
    #Rename some variables.
    temp.rename(columns={'open':'open_cf','cohort':'cohort_cf','intensity':'intensity_cf'},inplace=True)
    temp['cohort']=target['cohort']
    temp['open']=target['open']
    temp.drop_duplicates(subset=[geog,'date','cohort_cf'],inplace=True)

    #Merge in the treatment intensity (directly before and after open as well as quarterly)
    temp = pd.merge(temp,treatment,how='left',on=[geog,'date'])
    #Double check that the geog code doesn't have a change in intensity 
    #   (this will occur due to closure which is rare and wasn't checked for above)
    temp.sort_values(by=['cohort_cf',geog,'date'],inplace=True)
    temp['check'] = pd.isnull(temp['intensity'])
    #Only changes inside the buffer (3 years before to 2 years after the target's opening) disqualify a geog.
    #   The two date bounds must hold together (&). Joined with | the test was true for every date,
    #   so the buffer was never applied.
    temp['check_change'] = (((temp.date>target['open']-dt.timedelta(days=365*3))&(temp.date<target['open']+dt.timedelta(days=365*2)))&
                                        ((((temp['intensity']<threshold)|(temp.groupby(['cohort_cf',geog])['intensity'].shift(1)<threshold))&
                                         ((temp.groupby(['cohort_cf',geog])['check'].shift(1)==1)|(temp['intensity']<temp.groupby(['cohort_cf',geog])['intensity'].shift(1))))|
                                        (((temp['intensity']<threshold)|(temp.groupby(['cohort_cf',geog])['intensity'].shift(-1)<threshold))&
                                         ((temp.groupby(['cohort_cf',geog])['check'].shift(-1)==1)|(temp['intensity']<temp.groupby(['cohort_cf',geog])['intensity'].shift(-1))))))*1
    temp=temp[(temp.cohort_cf==0)|(temp.groupby(['cohort_cf',geog])['check_change'].transform('max')==0)]


    #Require that at least one of the geog in the counterfactual cohort is within threshold/2 and changed by more than threshold/2 (otherwise set cohort to 0 so that I can keep the close geogs as counterfactuals)
    temp['test'] = (temp.distance<threshold/2)&(temp.intensity_absdelta>threshold/2)
    temp['cohort_cf'] = np.where((temp.groupby('cohort_cf')['test'].transform('max')==0)&(temp['cohort_cf']>0),0,temp['cohort_cf'])
    temp = pd.merge(temp,dfp_geo[[geog, 'pclat', 'pclong']].drop_duplicates(subset=geog).rename(columns={'pclat':'clat','pclong':'clong'}),how='left',on=[geog])
    #Rows demoted to cohort 0 are located at the geog's own centroid instead of a plasma center.
    temp['pclat'] = np.where((temp.groupby('cohort_cf')['test'].transform('max')==1)&(temp['cohort_cf']>0),temp['pclat'],temp['clat'])
    temp['pclong'] = np.where((temp.groupby('cohort_cf')['test'].transform('max')==1)&(temp['cohort_cf']>0),temp['pclong'],temp['clong'])
    temp.drop(columns=['test','check','check_change','clat','clong'],inplace=True)
    #For parsimony require that counterfactuals pc openings are within 2000km of treatment cohort's plasma center.
    #   This allows for regional variation and that controls in a region (eg midwest) are better controls for treated regions
    #   However, I did it mostly to try and limit the size of the control dataset which was large at >140M rows.
    cont_dist = 2000000
    cont_degdist = 2*cont_dist/100000
    #One id per counterfactual unit: the cohort id, or 1000000000000 + geog for local counterfactuals.
    temp['cohcf_geog'] = np.where(temp.cohort_cf==0,1000000000000 + temp[geog],temp['cohort_cf'])
    temp2 = temp[['cohcf_geog','ctype','pclat','pclong']].drop_duplicates(subset=['cohcf_geog']).copy()
    temp2=temp2[(abs(temp2.pclat-target['pclat'])<cont_degdist)&(abs(temp2.pclong-target['pclong'])<cont_degdist)].copy()

   
    if temp.shape[0]>0:
        #Drop duplicates, calculate distances, and merge distances back
        try:
            temp2['distance_t2cf']=temp2.apply(lambda x: vinc.vincenty_inverse((x['pclat'],x['pclong']),(target['pclat'],target['pclong'])).m,axis=1)
        except Exception:
            #Fall back to the alternative distance routine when the vincenty call fails.
            temp2['distance_t2cf']=temp2.apply(lambda x: distance.distance((x['pclat'],x['pclong']),(target['pclat'],target['pclong'])).m,axis=1)
        temp2['ctype'] = np.where((temp2.cohcf_geog>1000000000000),0,temp2['ctype'])
        #Rank the counterfactual units by distance to the target's plasma center within each ctype.
        temp2.sort_values(by=['ctype','distance_t2cf'],ascending=[1,1],inplace=True)
        temp2['rank'] = 1
        temp2['rank'] = temp2.groupby('ctype')['rank'].cumsum()
        temp=pd.merge(temp,temp2[['cohcf_geog','distance_t2cf','rank']],how='left',on='cohcf_geog')
        #Keep local counterfactuals closer than 100000 and the 75 closest cohorts of each ctype.
        temp = temp[((temp.cohort_cf==0)&(temp.distance_t2cf<100000))|
                ((temp.cohort_cf>0)&(temp.ctype==1)&(temp['rank']<=75))|
                ((temp.cohort_cf>0)&(temp.ctype==2)&(temp['rank']<=75))]
        #NOTE(review): 'cohort' was set to target['cohort'] above, so this only has an effect for a cohort id of 0;
        #   ctype is nulled for cohort_cf==0 just below in any case.
        temp['ctype'] = np.where(temp.cohort==0,0,temp.ctype)
        #Null out some variables if it is the "local" counterfactual.
        for v in ['distance', 'ctype']:
            temp[v] = np.where(temp['cohort_cf']==0,np.nan,temp[v])
        for v in ['open_cf', 'open_prox', 'open_recent']:
            temp[v] = pd.to_datetime(np.where(temp['cohort_cf']==0,pd.NaT,temp[v]))
        kvars = [geog, 'date', 'cohort', 'open', 'cohort_cf', 'open_cf', 'distance_t2cf', 'distance', 'ctype',
                         'open_prox','open_recent','pclat','pclong']
        
        temp[kvars].to_parquet(cdd['p_d_geo_tcp'] + r'\Cohorts\control_coh_'+str(int(target['cohort']))+'_'+freq+'.parquet')
        return {**target,**{'return_code':1}}
    else:
        return {**target,**{'return_code':0}}


#Reload the intermediate datasets written by the steps above.
file_tag = str(int(radius/1000))+'k_'+geog+'_'+freq
treated_cohs = pd.read_parquet(cdd['p_d_geo_tcp'] + r'\treated_cohs_'+file_tag+'.parquet')
treated_coh = pd.read_parquet(cdd['p_d_geo_tcp'] + r'\treated_coh_'+file_tag+'.parquet')
treat = pd.read_parquet(cdd['p_d_geo_tcp'] + r'\\treat_'+file_tag+'.parquet')
cohorts = pd.read_parquet(cdd['p_d_geo_tcp'] + r'\cohorts_'+file_tag+'.parquet')


#Balanced geog-by-date panel (cohort id 0) used to fill in areas where there hasn't been an opening.
#   The geog centroid is stored under the plasma-center coordinate names.
dfp_geo = bal_panel(index={geog:df_geo[geog].unique(),'date':dates.date.unique()})
dfp_geo['cohort'] = 0
dfp_geo = pd.merge(dfp_geo,df_geo[[geog,'clat','clong']],how='left',on=geog)
dfp_geo = dfp_geo.rename(columns={'clat':'pclat','clong':'pclong'})


#One target record per usable cohort (cohort id above 1000 and use_lib set), unique on cohort and opening date.
usable = (cohorts.cohort>1000)&(cohorts.use_lib==1)
cohorts_s = cohorts[usable].drop_duplicates(subset=['cohort','open'])
targetsg = cohorts_s.to_dict(orient='records')
# target=cohorts[cohorts.cohort==517858136221].to_dict(orient='records')[0]
#Slim copies of the inputs handed to every control-building job.
cand_vars = [geog, 'date', 'cohort']
candidates = treated_coh[cand_vars].copy()
treat_vars = [geog,'date','intensity']
treatment = treat[treat_vars].copy()
treatment.info(memory_usage='deep')
treated = []#list(treated_cohs[treated_cohs.cohort==83199857151][geog].unique())

#Collect the cohort ids whose control file already exists so that only the remaining targets are processed.
#   File names end in <cohort>_<freq>.parquet. One compiled pattern now does both the filtering and the
#   id extraction; the extraction previously relied on `tfreq`, which is not defined in this section,
#   and on a non-raw '\.' escape.
done_pattern = re.compile(r'([\d]{1,})_'+re.escape(freq)+r'\.parquet')
done_matches = (done_pattern.search(fname) for fname in os.listdir(cdd['p_d_geo_tcp'] + r'\Cohorts'))
coh = [int(m.group(1)) for m in done_matches if m]
done_ids = set(coh)
targetsgr = [v for v in targetsg if (v['cohort'] not in done_ids)]



#Build the control panel for every remaining target in parallel; each job receives its own copies of the inputs.
#   temp collects the per-target return records and t ends up holding the elapsed seconds.
run_started = time.time()
control_jobs = (
    delayed(cohort_controls)(
        target=tgt,
        cohorts=cohorts.copy(),
        treated=[],
        candidates=candidates.copy(),
        treatment=treatment.copy(),
        dfp_geo=dfp_geo.copy(),
        threshold=10000,
    )
    for tgt in tqdm(targetsgr)
)
temp = Parallel(n_jobs=10)(control_jobs)
t = time.time()-run_started



# Import and concatenate the control cohorts: every parquet file in the Cohorts folder
# whose path matches <digits>_<freq>.parquet is read in parallel and stacked.
_control_file_pat = r'([\d]{1,})_'+freq+'.parquet'
coh = [path for path in glob.glob(cdd['p_d_geo_tcp'] + r'\Cohorts\*.parquet')
       if any(re.findall(_control_file_pat, path))]
temp = Parallel(n_jobs=10)(delayed(pd.read_parquet)(path) for path in tqdm(coh))
control_coh = pd.concat(temp, ignore_index=True, sort=False)

# Save the combined controls, then print their memory footprint.
_control_coh_path = cdd['p_d_geo_tcp'] + r'\control_coh_'+str(int(radius/1000))+'k_'+geog+'_'+freq+'.parquet'
control_coh.to_parquet(_control_coh_path)
control_coh.info(memory_usage='deep')


# Concatenate the treated and control regions panel into tcp.
_control_coh_file = cdd['p_d_geo_tcp'] + r'\control_coh_'+str(int(radius/1000))+'k_'+geog+'_'+freq+'.parquet'
control_coh = pd.read_parquet(_control_coh_file)
control_coh['treated'] = 0

# Treated rows pick up pclat/pclong from pcp, matched on cohort (pcp.mfeiu renamed to cohort).
# NOTE(review): treated_cohs is not assigned in this section (treated_coh is what gets loaded
#   from parquet above) -- confirm this is the intended frame.
cohort_coords = pcp[['mfeiu', 'clat', 'clong']].rename(
    columns={'mfeiu': 'cohort', 'clat': 'pclat', 'clong': 'pclong'})
tp = pd.merge(treated_cohs, cohort_coords, how='left', on=['cohort'])
tp = tp.drop(columns=['through', 'open_qpre'])

tcp = pd.concat([tp, control_coh], ignore_index=True, sort=False)
# Keep observations dated on or before 2021-12-31.
tcp = tcp[tcp['date'] <= dt.datetime(2021, 12, 31)]


#Create the variable containing event-time (etime): the number of periods between an observation date and the opening date.
#   It is computed once per distinct (date, open) pair and then merged back onto tcp.
#   NOTE: etime is only created when freq is 'm' or 'q'; any other value leaves it missing and the lines below fail.
temp_etime=tcp.drop_duplicates(subset=['date','open'])[['date','open']]
#Monthly: calendar-month difference between date and open (np.ceil returns a float; the replace maps -0.0 to 0).
if freq=='m':
    temp_etime['etime']=temp_etime.apply(lambda x: np.ceil((12*(x['date'].year-x['open'].year)+(x['date'].month-x['open'].month))), axis=1).replace(-0.0,0)
#Quarterly: ceil(month difference / 3) - 1, plus 1 whenever open falls in month 3, 6, 9 or 12.
if freq=='q':
    temp_etime['etime']=temp_etime.apply(lambda x: np.ceil((12*(x['date'].year-x['open'].year)+(x['date'].month-x['open'].month))/3) - 1, axis=1).replace(-0.0,0) + (temp_etime['open'].dt.month.isin([3,6,9,12]))*1
#merge_date: first keep the date only on rows where etime is 0 or -1 (NaT elsewhere) ...
temp_etime['merge_date'] = pd.to_datetime(np.where(temp_etime.etime.isin([0,-1]),temp_etime['date'],pd.NaT))
#   ... then, within each open, rows with etime>=0 take the latest of those dates and rows with etime<0 take the earliest.
temp_etime['merge_date'] = pd.to_datetime(np.where(temp_etime.etime>=0,temp_etime.groupby('open')['merge_date'].transform('max'),temp_etime.groupby('open')['merge_date'].transform('min')))
#open_qpre: the date observed at etime == -1, spread to every row that shares the same open.
temp_etime['open_qpre'] = pd.to_datetime(np.where(temp_etime.etime == -1,temp_etime['date'],pd.NaT))
temp_etime['open_qpre'] = temp_etime.groupby('open')['open_qpre'].transform('max')
#Bring etime, merge_date and open_qpre back onto the full panel.
tcp = pd.merge(tcp,temp_etime,how='left',on=['date','open'])


# Keep only observations for built PCs (cohort id above 10000) that fall within the event window.
# nper is the number of periods per unit of `window` (12 for monthly, 4 for quarterly data).
fm = {'M': 12, 'Q': 4}
nper = fm[freq.upper()]
in_event_window = tcp['etime'].abs() <= nper*window
tcp = tcp[(tcp['cohort'] > 10000) & in_event_window]


# Add the distance in time between open and open_cf, measured in days / 365.
open_gap = tcp['open_cf'] - tcp['open']
tcp['date_del'] = open_gap.dt.days / 365

# Sometimes the region around an opening is included as a control (ctype is missing) when it
# should instead be treated. Drop rows with missing ctype and distance_t2cf below `threshold`
# whenever their cohort-date group holds more than one geography.
ctype_missing = tcp['ctype'].isnull()
inside_threshold = tcp['distance_t2cf'] < threshold
geogs_per_cohort_date = tcp.groupby(['cohort', 'date'])[geog].transform('count')
tcp = tcp[~(ctype_missing & inside_threshold & (geogs_per_cohort_date > 1))]

# A geog code was not prevented from appearing as both treated and control in the same cohort
# when another plasma center was closer than the treatment. Control takes priority, since it
# means we are at least threshold away from the opening (i.e. low-intensity treatment if anything):
# keep only the rows whose treated flag equals the minimum for that cohort-geog pair.
lowest_treated_flag = tcp.groupby(['cohort', geog])['treated'].transform('min')
tcp = tcp[tcp['treated'] == lowest_treated_flag]

# Set ctype to 0 for the local counterfactuals (control rows with cohort_cf == 0);
# any ctype still missing afterwards becomes 5.
is_local_cf = (tcp['treated'] == 0) & (tcp['cohort_cf'] == 0)
tcp['ctype'] = np.where(is_local_cf, 0, tcp['ctype'])
tcp['ctype'] = tcp['ctype'].fillna(5)

# Subset to observations within 10km (distance < 10000) of the treatment cohort, or local
# counterfactuals whose distance_t2cf is below the same cut-off.
near_treatment = tcp['distance'] < 10000
near_local_cf = (tcp['ctype'] == 0) & (tcp['distance_t2cf'] < 10000)
tcp = tcp[near_treatment | near_local_cf]

#Add a column with a random number for each cohort-cohort_cf group.
#   The seed is fixed so the draw below is the same on every run.
import random as rn
rn.seed(10)
#tcp = pd.read_pickle(cd_geog + r'\\'+geog+'_tcp.pkl')
#Number each (cohort, cohort_cf) group; dropna=False keeps rows with a missing key as their own group.
tcp['randomint'] = tcp.groupby(['cohort','cohort_cf'],dropna=False).ngroup()
#Draw one distinct integer from 1..999999 per group (sampling without replacement) and map it onto the rows.
rnmap = dict(zip(tcp['randomint'].unique(),rn.sample(range(1, 1000000), len(tcp['randomint'].unique()))))
#Treated rows always get 0 instead of a random number.
tcp['randomint'] = np.where(tcp.treated==1,0,tcp['randomint'].map(rnmap))
#randomrank: after sorting, flag the first row of each distinct randomint within (treated, cohort) with 1
#   and take the running sum, i.e. the position of that random number within the cohort. Treated rows get 0.
tcp.sort_values(by=['treated','cohort','randomint'],ascending=[1,1,1],inplace=True)
tcp['randomrank'] = 1 - tcp.duplicated(subset=['treated','cohort','randomint'])
tcp['randomrank'] = tcp.groupby(['treated','cohort'])['randomrank'].cumsum()
tcp['randomrank'] = np.where(tcp['treated']==1,0,tcp['randomrank'])
#distance_t2cf_rank: the same first-occurrence / running-sum ranking, applied to distance_t2cf.
#   Only control rows with cohort_cf>0 carry their distance; every other row shares a NaN in the temporary column.
#   NOTE: this rank is calculated again further down the file, after more rows have been filtered out.
tcp['temp_t2cf'] = np.where((tcp.treated==0)&(tcp.cohort_cf>0),tcp['distance_t2cf'],np.nan)
tcp.sort_values(by=['treated','cohort','temp_t2cf'],ascending=[1,1,1],inplace=True)
tcp['distance_t2cf_rank'] = 1 - tcp.duplicated(subset=['treated','cohort','temp_t2cf'])
tcp['distance_t2cf_rank'] = tcp.groupby(['treated','cohort'])['distance_t2cf_rank'].cumsum()
tcp['distance_t2cf_rank'] = np.where(tcp['treated']==1,0,tcp['distance_t2cf_rank'])
tcp.drop(columns=['temp_t2cf'],inplace=True)



#Merge in the change in treatment intensity at treatment or at counterfactual treatment
#   Counterfactual openings with a year after 2022 are moved to 2021-12-31, and a missing open_cf falls back to open.
tcp.loc[tcp['open_cf'].dt.year>2022,'open_cf'] = dt.datetime(2021,12,31)
tcp['open_cf'] = tcp['open_cf'].fillna(tcp['open'])
#For each distinct open_cf, find the first date in `dates` that is strictly later; this becomes the merge key below.
openqp = {d:dates[dates['date']>d]['date'].min() for d in tcp['open_cf'].unique()}
tcp['open_cf_qpost'] = tcp['open_cf'].map(openqp)
#A missing cohort_cf falls back to the row's own cohort.
tcp['cohort_cf'] = tcp['cohort_cf'].fillna(tcp['cohort'])

#Merge in the change in treatment intensity that occurs for treated cohorts or will (did) occur for future (past) counterfactuals.
#   The values come from `treat`, matched on (open_cf_qpost, geog), and are stored with a _tcf suffix.
tcp = pd.merge(tcp,treat[['date',geog,'intensity_absdelta','intensity']].rename(columns={'date':'open_cf_qpost','intensity_absdelta':'intensity_absdelta_tcf','intensity':'intensity_tcf'}),how='left',on=['open_cf_qpost',geog])
#Rows whose open_cf year is 2000 or earlier get a fixed change of 100000.
tcp.loc[tcp['open_cf'].dt.year<=2000,'intensity_absdelta_tcf'] = 100000
#NOTE(review): this fill only touches intensity_tcf on rows with open_cf year<=2000; those rows are not read by the
#   year>2020 branch below and the column is dropped at the end, so it does not reach the saved output.
tcp['intensity_tcf'] = np.where(tcp['open_cf'].dt.year<=2000,tcp['intensity_tcf'].fillna(tcp['distance']),tcp['intensity_tcf'])
#Rows whose open_cf year is after 2020: 100000 when no intensity was matched, otherwise |distance - intensity_tcf|.
tcp['intensity_absdelta_tcf'] = np.where((tcp['open_cf'].dt.year>2020), np.where(pd.isnull(tcp['intensity_tcf']), 100000, np.abs(tcp['distance'] - tcp['intensity_tcf'])),tcp['intensity_absdelta_tcf'])
#NOTE(review): intensity_tcf is filled here and dropped on the very next line, so this fill has no effect on the output.
tcp['intensity_tcf'] = np.where(tcp['open_cf'].dt.year>2020,tcp['intensity_tcf'].fillna(tcp['distance']),tcp['intensity_tcf'])
tcp.drop(columns=['intensity_tcf'],inplace=True)

# Require that each cohort has at least one treated geography and one control geography:
# a cohort survives only when its highest and lowest treated flags differ.
treated_by_cohort = tcp.groupby('cohort')['treated']
has_both_arms = treated_by_cohort.transform('max') != treated_by_cohort.transform('min')
tcp = tcp[has_both_arms]

# Rank the control geographies by distance_t2cf within each (treated, cohort) group.
# Only control rows with cohort_cf > 0 carry their distance; all other rows share a NaN key.
rank_keys = ['treated', 'cohort']
ranked_controls = (tcp['treated'] == 0) & (tcp['cohort_cf'] > 0)
tcp['temp_t2cf'] = np.where(ranked_controls, tcp['distance_t2cf'], np.nan)
tcp = tcp.sort_values(by=rank_keys + ['temp_t2cf'], ascending=[1, 1, 1])
# Flag the first row of every distinct distance with 1; the running sum per group is the rank.
first_of_distance = 1 - tcp.duplicated(subset=rank_keys + ['temp_t2cf'])
tcp['distance_t2cf_rank'] = first_of_distance
tcp['distance_t2cf_rank'] = tcp.groupby(rank_keys)['distance_t2cf_rank'].cumsum()
# Treated rows are not ranked and get 0.
tcp['distance_t2cf_rank'] = np.where(tcp['treated'] == 1, 0, tcp['distance_t2cf_rank'])
tcp = tcp.drop(columns=['temp_t2cf'])


# Save out this skeleton treatment dataset (skeleton because it doesn't yet have the
# treatment intensity), read it straight back from disk and print its memory footprint.
_tcp_path = cdd['p_d_geo_tcp'] + r'\\tcp_'+str(int(radius/1000))+'k_'+geog+'_'+freq+'.parquet'
tcp.to_parquet(_tcp_path)
tcp = pd.read_parquet(_tcp_path)
tcp.info(memory_usage='deep')

